package linc.fun.openai.handler.emitter;

import cn.hutool.core.util.StrUtil;
import com.mybatisflex.core.query.QueryWrapper;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.utils.TikTokensUtil;
import jakarta.annotation.Resource;
import linc.fun.openai.config.openai.OpenaiConfig;
import linc.fun.openai.domain.dto.request.ChatProcessRequest;
import linc.fun.openai.domain.entity.chat.ChatMessageDO;
import linc.fun.openai.domain.entity.chat.ChatUserDO;
import linc.fun.openai.domain.vo.ChatReplyMessageVO;
import linc.fun.openai.enums.*;
import linc.fun.openai.exception.BizException;
import linc.fun.openai.openai.apikey.ApiKeyChatClientBuilder;
import linc.fun.openai.openai.listener.ParsedEventSourceListener;
import linc.fun.openai.openai.listener.ResponseBodyEmitterStreamListener;
import linc.fun.openai.openai.parser.ChatCompletionResponseParser;
import linc.fun.openai.openai.storage.ApiKeyDatabaseDataStorage;
import linc.fun.openai.service.ChatMessageService;
import linc.fun.openai.service.ChatUserService;
import linc.fun.openai.util.ChatUserUtil;
import linc.fun.openai.util.ObjectMapperUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.LinkedList;
import java.util.Objects;

import static linc.fun.openai.domain.entity.chat.table.Tables.ChatMessage;
import static linc.fun.openai.domain.entity.chat.table.Tables.ChatUser;


/**
 * Response emitter for the API-key based OpenAI chat client.
 * <p>
 * Builds the prompt (optional system message plus the conversation history reachable through
 * parent-message links), enforces token limits, streams the completion into the given
 * {@link ResponseBodyEmitter} and deducts one conversation from the current user.
 *
 * @author linc
 * @date 2023-3-24
 */
@Slf4j
@Component
public class ApiKeyResponseEmitter implements ResponseEmitter {

    /**
     * Soft budget for the context tokens (history + current prompt, before the system message).
     * The oldest question/answer pairs are dropped until the context fits; if it still does not,
     * the request is rejected.
     */
    private static final int CONTEXT_TOKEN_SOFT_LIMIT = 3000;

    /**
     * If the model limit leaves this many tokens or fewer for the completion, the context is
     * considered too long to continue the conversation.
     */
    private static final int MIN_REMAINING_COMPLETION_TOKENS = 10;

    @Resource
    private OpenaiConfig openaiConfig;

    @Resource
    private ChatMessageService chatMessageService;

    @Resource
    private ChatCompletionResponseParser parser;

    @Resource
    private ApiKeyDatabaseDataStorage dataStorage;

    @Resource
    private ChatUserService chatUserService;

    @Override
    public void requestToResponseEmitter(ChatProcessRequest chatProcessRequest, ResponseBodyEmitter emitter) {

        // Look up the currently logged-in user
        ChatUserDO user = chatUserService.getById(ChatUserUtil.getUserId());
        if (Objects.isNull(user)) {
            throw BizException.CURRENT_USER_NOT_EXIST;
        }

        // Persist the new chat message, then collect its context (oldest message first)
        ChatMessageDO chatMessageDO = chatMessageService.initChatMessage(chatProcessRequest, ApiTypeEnum.API_KEY);
        LinkedList<Message> messages = new LinkedList<>();
        addContextChatMessage(chatMessageDO, messages);
        // Use one model name for every token count so the trimming loop and the checks agree
        String modelName = chatMessageDO.getModelName();

        log.info("requestToResponseEmitter，请求原消息内容: {}", ObjectMapperUtil.toJson(messages));
        // Drop the oldest question/answer pair until the context fits the soft budget
        int totalTokenCount = TikTokensUtil.tokens(modelName, messages);
        while (totalTokenCount > CONTEXT_TOKEN_SOFT_LIMIT && messages.size() > 3) {
            messages.removeFirst();
            messages.removeFirst();
            totalTokenCount = TikTokensUtil.tokens(modelName, messages);
        }
        log.info("requestToResponseEmitter，请求原消息调整后内容: {}", ObjectMapperUtil.toJson(messages));
        // Still too long after trimming: reject the request outright
        if (totalTokenCount > CONTEXT_TOKEN_SOFT_LIMIT) {
            throw BizException.QUESTION_TOO_LONG;
        }

        // Prepend the optional system (role) message
        if (StrUtil.isNotBlank(chatProcessRequest.getSystemMessage())) {
            Message systemMessage = Message.builder()
                    .role(Message.Role.SYSTEM)
                    .content(chatProcessRequest.getSystemMessage())
                    .build();
            messages.addFirst(systemMessage);
        }

        // Token count of the final prompt (system message + context), recorded as promptTokens
        totalTokenCount = TikTokensUtil.tokens(modelName, messages);
        chatMessageDO.setPromptTokens(totalTokenCount);

        // Reject prompts that exceed the model's limit; the helper has already notified the client
        String exceedModelTokenLimitMsg = exceedModelTokenLimit(chatProcessRequest, modelName, totalTokenCount, emitter);
        if (Objects.nonNull(exceedModelTokenLimitMsg)) {
            chatMessageDO.setStatus(ChatMessageStatusEnum.EXCEPTION_TOKEN_EXCEED_LIMIT);
            chatMessageDO.setResponseErrorData(exceedModelTokenLimitMsg);
            chatMessageService.updateById(chatMessageDO);
            return;
        }

        // Completion budget = model limit - tokens consumed by this prompt - 1.
        // NOTE(review): this limit is looked up by openaiConfig.getApiModel() while the check above uses
        // the message's model name; confirm the two always match, otherwise maxTokens may be wrong.
        int maxTokens = ApiKeyTokenLimiterEnum.getTokenLimitByOuterJarModelName(openaiConfig.getApiModel()) - totalTokenCount - 1;
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .maxTokens(maxTokens)
                .model(openaiConfig.getApiModel())
                // range [0, 2]; lower values are more deterministic
                .temperature(0.8)
                .topP(1.0)
                // generate a single choice per request
                .n(1)
                .presencePenalty(1)
                .messages(messages)
                .stream(true)
                .build();

        log.info("requestToResponseEmitter==>最大限制tokens: {},当前请求tokens: {}", maxTokens, chatCompletion.tokens());

        // Event listener that forwards parsed stream events to the emitter and storage
        ParsedEventSourceListener parsedEventSourceListener = new ParsedEventSourceListener.Builder()
                .addListener(new ResponseBodyEmitterStreamListener(emitter))
                .setParser(parser)
                .setDataStorage(dataStorage)
                .setOriginalRequestData(ObjectMapperUtil.toJson(chatCompletion))
                .setChatMessageDO(chatMessageDO)
                .build();

        ApiKeyChatClientBuilder.buildOpenAiStreamClient().streamChatCompletion(chatCompletion, parsedEventSourceListener);
        // TODO: upstream errors such as
        //  {"error": {"type": "server_error", "message": "That model is currently overloaded with other requests..."}}
        //  are not handled here; the deduction below is not conditioned on how the stream ends.
        decreaseConversationTimes(user);
    }

    /**
     * Deducts one conversation from {@code user} and persists it.
     * <p>
     * The update is conditioned on the stored count still equal to the value read earlier, so a
     * concurrent change to the user row is not overwritten. It is also conditioned on that value
     * being greater than zero, so the count never drops below zero. A rejected update is only logged.
     *
     * @param user the user loaded at the start of the request
     */
    private void decreaseConversationTimes(ChatUserDO user) {
        var previousTimes = user.getCanConversationTimes();
        user.setCanConversationTimes(previousTimes - 1);
        user.setUpdateTime(LocalDateTime.now());
        boolean ok = chatUserService.update(user, QueryWrapper.create()
                .where(ChatUser.Id.eq(user.getId()))
                // optimistic check against the value we decremented from
                .and(ChatUser.CanConversationTimes.eq(previousTimes))
                // ge(0) would still let a stored 0 become -1
                .and(ChatUser.CanConversationTimes.gt(0))
        );
        if (!ok) {
            log.error("扣减聊天次数失败，用户信息：{}", ObjectMapperUtil.toJson(user));
        }
    }

    /**
     * Checks whether the prompt's token count leaves room for a completion under the model's limit.
     * When it does not, a {@link ChatReplyMessageVO} carrying the explanation is sent to the client
     * and the emitter is completed (unless sending itself failed).
     *
     * @param chatProcessRequest chat request
     * @param modelName          name of the model in use
     * @param tokenCount         total token count of the prompt (system message + context)
     * @param emitter            emitter used to notify the client
     * @return the explanation sent to the client, or {@code null} if the prompt is within limits
     */
    private String exceedModelTokenLimit(ChatProcessRequest chatProcessRequest, String modelName, int tokenCount, ResponseBodyEmitter emitter) {
        // Maximum tokens supported by the current model
        int maxTokens = ApiKeyTokenLimiterEnum.getTokenLimitByOuterJarModelName(modelName);

        String msg;
        if (ApiKeyTokenLimiterEnum.exceedsLimit(modelName, tokenCount)) {
            // Distinguish "the history alone is too long" from "this prompt pushed it over the limit"
            int currentPromptTokens = TikTokensUtil.tokens(modelName, chatProcessRequest.getPrompt());
            int historyTokens = tokenCount - currentPromptTokens;
            if (ApiKeyTokenLimiterEnum.exceedsLimit(modelName, historyTokens)) {
                msg = "当前上下文字数已经达到上限，请关闭上下文或开启新的对话";
            } else {
                msg = StrUtil.format("当前上下文 Token 数量：{}，超过上限：{}，请减少字数发送或关闭上下文或开启新的对话", tokenCount, maxTokens);
            }
        } else if (maxTokens - tokenCount <= MIN_REMAINING_COMPLETION_TOKENS) {
            // Too few tokens left for a meaningful completion
            msg = "当前上下文字数不足以连续对话，请关闭上下文或开启新的对话";
        } else {
            return null;
        }

        boolean sendFailed = false;
        try {
            ChatReplyMessageVO chatReplyMessageVO = ChatReplyMessageVO.onEmitterChainException(chatProcessRequest);
            chatReplyMessageVO.setText(msg);
            emitter.send(ObjectMapperUtil.toJson(chatReplyMessageVO));
        } catch (IOException e) {
            // The client is most likely gone. Spring documents that complete() should not be called after
            // a send error (the container finishes the request), and the caller still needs msg to record
            // the failure, so log instead of rethrowing.
            sendFailed = true;
            log.warn("exceedModelTokenLimit，向客户端发送提示消息失败: {}", msg, e);
        } finally {
            if (!sendFailed) {
                emitter.complete();
            }
        }
        return msg;
    }

    /**
     * Prepends {@code chatMessageDO} and its ancestors to {@code messages}, walking up the
     * parent-message links so that older messages end up first.
     * <p>
     * A message without a parent message id is added as a USER message and ends the walk. An answer whose
     * status is neither PART_SUCCESS nor COMPLETE_SUCCESS is skipped together with its question:
     * the walk continues from its parent answer message instead (or stops if there is none).
     *
     * @param chatMessageDO message to start from; {@code null} ends the walk
     * @param messages      list that context messages are prepended to
     */
    private void addContextChatMessage(ChatMessageDO chatMessageDO, LinkedList<Message> messages) {
        if (Objects.isNull(chatMessageDO)) {
            return;
        }
        // No parent: this is the first message of the conversation
        if (Objects.isNull(chatMessageDO.getParentMessageId())) {
            messages.addFirst(Message.builder().role(Message.Role.USER)
                    .content(chatMessageDO.getContent())
                    .build());
            return;
        }

        // Both questions and answers go into the context; pick the role by message type
        Message.Role role = (chatMessageDO.getMessageType() == ChatMessageTypeEnum.ANSWER) ?
                Message.Role.ASSISTANT : Message.Role.USER;

        // Unsuccessful answer: skip it and its question, continue from the previous answer
        if (chatMessageDO.getMessageType() == ChatMessageTypeEnum.ANSWER
                && chatMessageDO.getStatus() != ChatMessageStatusEnum.PART_SUCCESS
                && chatMessageDO.getStatus() != ChatMessageStatusEnum.COMPLETE_SUCCESS) {
            if (Objects.isNull(chatMessageDO.getParentAnswerMessageId())) {
                return;
            }
            ChatMessageDO parentMessage = chatMessageService.getOne(QueryWrapper.create()
                    .where(ChatMessage.MessageId.eq(chatMessageDO.getParentAnswerMessageId())));
            this.addContextChatMessage(parentMessage, messages);
            return;
        }

        // Walking from newest to oldest, so each message is prepended
        messages.addFirst(Message.builder().role(role)
                .content(chatMessageDO.getContent())
                .build());

        ChatMessageDO parentMessage = chatMessageService.getOne(QueryWrapper.create()
                .where(ChatMessage.MessageId.eq(chatMessageDO.getParentMessageId())));

        this.addContextChatMessage(parentMessage, messages);
    }
}
